import igraph
import numpy as np
import torch
import pickle
import json

# Embedding table indexed by ontology node-type id (see its use in
# _construct_graph_init).
# NOTE(review): loaded at import time from a path relative to the current
# working directory. pickle.load can execute arbitrary code, so only point
# this at a trusted file.
with open("../D_VAE/data/kairos_ontology_embeddings.pkl", "rb") as _embeddings_file:
    node_type_embeddings = pickle.load(_embeddings_file)

def _load_ontology_dict():
    """Build the node-type and edge-type id tables from the ontology pickle.

    Reads ``../D_VAE/data/kairos_ontology.pkl`` (relative to the current
    working directory) and uses the first element of the unpickled object,
    which must be a mapping with the keys ``event_types``,
    ``event_entity_relations``, ``entity_types_dict`` and
    ``entity_relation_dict``. The last three must map names to numeric ids,
    because an offset is added to them.

    Returns:
        A ``(node_type_dict, edge_type_dict)`` tuple of dicts.

        ``node_type_dict`` holds the event types with their stored ids and
        the entity types with their stored ids shifted by 67.

        ``edge_type_dict`` reserves ids 0, 1 and 2 for
        ``"event-event-before-after"``, ``"no-edge"`` and
        ``"identical-edge"``; event-entity relations follow, shifted by 3,
        and entity relations are shifted by ``3 + 85``.
    """
    # Offsets that move each id family into its own range.
    # NOTE(review): 67 and 85 are presumably the number of event types and
    # of event-entity relation types in the ontology -- confirm.
    entity_type_offset = 67
    event_entity_relation_offset = 3
    entity_relation_offset = 3 + 85

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open("../D_VAE/data/kairos_ontology.pkl", "rb") as ontology_file:
        saved_dict = pickle.load(ontology_file)[0]
    event_types = saved_dict["event_types"]
    event_entity_relations = saved_dict["event_entity_relations"]
    entity_types_dict = saved_dict["entity_types_dict"]
    entity_relation_dict = saved_dict["entity_relation_dict"]

    node_type_dict = dict(event_types.items())
    for key, val in entity_types_dict.items():
        node_type_dict[key] = val + entity_type_offset

    edge_type_dict = {
        "event-event-before-after": 0,
        "no-edge": 1,
        "identical-edge": 2,
    }
    for key, val in event_entity_relations.items():
        edge_type_dict[key] = val + event_entity_relation_offset
    for key, val in entity_relation_dict.items():
        edge_type_dict[key] = val + entity_relation_offset

    return node_type_dict, edge_type_dict


# Module-level id tables built once at import time: node-type name -> id and
# edge-type name -> id, exactly as returned by _load_ontology_dict().
node_type_ontology, edge_type_ontology = _load_ontology_dict()


def filter_relations(dataset):
    """Prune each graph down to what its event links can reach.

    ``dataset`` is a five-element sequence of per-graph lists, in this order:
    events, event-entity relations, event links, entities, entity relations.
    For every graph:

    * an event is kept when its id (item 0) occurs as item 0 or item 1 of
      some event link;
    * an event-entity relation is kept when its event (item 0) is such a
      linked event, and its entity (item 2) then counts as mentioned;
    * an entity relation is kept when both of its endpoints (items 0 and 2)
      are mentioned entities;
    * an entity is kept when its id (item 0) is a mentioned entity.

    Returns a list with the same five-slot layout. The event-link slot is the
    input object itself, untouched.
    """
    (events_by_graph, event_entity_by_graph, links_by_graph,
     entities_by_graph, entity_rel_by_graph) = dataset

    kept_events_all = []
    kept_event_entity_all = []
    kept_entities_all = []
    kept_entity_rel_all = []

    for graph_idx, links in enumerate(links_by_graph):
        event_entity_relations = event_entity_by_graph[graph_idx]
        entity_relations = entity_rel_by_graph[graph_idx]

        # Ids of all events that take part in at least one link.
        linked_events = set()
        for link in links:
            linked_events.add(link[0])
            linked_events.add(link[1])

        kept_events = [event for event in events_by_graph[graph_idx]
                       if event[0] in linked_events]

        # Relations hanging off a linked event, and the entities they name.
        kept_event_entity = [relation for relation in event_entity_relations
                             if relation[0] in linked_events]
        mentioned_entities = {relation[2] for relation in kept_event_entity}

        kept_entity_rel = []
        for relation in entity_relations:
            head, tail = relation[0], relation[2]
            if head in mentioned_entities and tail in mentioned_entities:
                kept_entity_rel.append(relation)

        kept_entities = [entity for entity in entities_by_graph[graph_idx]
                         if entity[0] in mentioned_entities]

        kept_event_entity_all.append(kept_event_entity)
        kept_entity_rel_all.append(kept_entity_rel)
        kept_events_all.append(kept_events)
        kept_entities_all.append(kept_entities)

    return [kept_events_all, kept_event_entity_all, links_by_graph,
            kept_entities_all, kept_entity_rel_all]





def _construct_graph_init(events, entities, event_entity_relas, event_links):
    """Build the initial graph of one document: events, links and arguments.

    Nodes ``0 .. len(events) - 1`` are the events, in input order. Every
    event-entity relation then gets a node of its own (``len(events) + i``
    for relation ``i``), so an entity named by several relations is
    represented by several nodes. ``node_dict`` records which node indices
    belong to which event or entity id.

    Args:
        events: sequence of ``(event_id, event_type)`` pairs.
        entities: sequence whose items carry the entity id at index 0 and
            the entity type at index 1.
        event_entity_relas: sequence of ``(event_id, rela_type, entity_id)``
            triples.
        event_links: sequence of ``(start_event, end_event)`` pairs.

    Returns:
        A tuple ``(A_init, node_dict, g, x_features, entity_type_dict)``:

        * ``A_init``: array of shape
          ``(num_nodes, num_nodes, len(edge_type_ontology))``. Every node
          pair starts as "no-edge" (channel 1); event links switch the pair
          to channel 0 and event-entity relations to the relation's own
          channel, in both directions.
        * ``node_dict``: id -> set of node indices.
        * ``g``: undirected ``igraph.Graph`` with the same edges.
        * ``x_features``: one ``node_type_embeddings`` entry per node, in
          node order.
        * ``entity_type_dict``: entity id -> entity type.

    Raises:
        KeyError: if a dict lookup fails, e.g. a link or relation names an
            event missing from ``events``, or a relation names an entity
            missing from ``entities``.
    """
    num_events = len(events)
    num_nodes = num_events + len(event_entity_relas)
    g = igraph.Graph(directed=False)
    g.add_vertices(num_nodes)

    # Every node pair starts out as "no-edge" (channel 1).
    A_init = np.zeros((num_nodes, num_nodes, len(edge_type_ontology)))
    A_init[:, :, 1] = 1

    entity_type_dict = {entity[0]: entity[1] for entity in entities}

    node_dict = {}
    x_features = []
    for node_id, (event_id, event_type) in enumerate(events):
        node_dict[event_id] = {node_id}
        x_features.append(node_type_embeddings[node_type_ontology[event_type]])

    for start_event, end_event in event_links:
        start_id = next(iter(node_dict[start_event]))
        end_id = next(iter(node_dict[end_event]))
        g.add_edge(start_id, end_id)
        for src, dst in ((start_id, end_id), (end_id, start_id)):
            A_init[src, dst, 1] = 0
            A_init[src, dst, 0] = 1

    for offset, (event_id, rela_type, entity_id) in enumerate(event_entity_relas):
        # Each relation gets a fresh node, even when the entity was seen before.
        entity_node = num_events + offset
        node_dict.setdefault(entity_id, set()).add(entity_node)
        entity_type = entity_type_dict[entity_id]
        x_features.append(node_type_embeddings[node_type_ontology[entity_type]])

        event_node = next(iter(node_dict[event_id]))
        rela_id = edge_type_ontology[rela_type]
        for src, dst in ((event_node, entity_node), (entity_node, event_node)):
            A_init[src, dst, 1] = 0
            A_init[src, dst, rela_id] = 1
        g.add_edge(event_node, entity_node)
    return A_init, node_dict, g, x_features, entity_type_dict



def _construct_A_true(A_true, node_dict, g, entity_relas, entities_dict):
    """Add coreference and entity-entity edges to a graph, in place.

    Two kinds of edges are written into ``A_true`` (in both directions) and
    added to ``g``:

    * "identical-edge" (channel 2) between every pair of nodes that
      ``node_dict`` lists under the same id;
    * the relation's own channel (``edge_type_ontology[rela_type]``) between
      every node of the start entity and every node of the end entity of an
      entity relation. Self-relations are skipped, and only the first
      relation seen for an unordered entity pair is used (presumably so
      that a cell never receives two edge types).

    Every written cell also has its "no-edge" channel (1) cleared.

    Args:
        A_true: array indexed ``[node, node, edge_type]``; modified in place.
        node_dict: id -> set of node indices.
        g: graph object providing ``add_edge(u, v)``; modified in place.
        entity_relas: sequence of ``(start_entity, rela_type, end_entity)``
            triples.
        entities_dict: container of the known entity ids (only membership is
            tested).

    Returns:
        ``(A_true, node_dict, g, entity_relas)`` -- the objects passed in.

    Raises:
        RuntimeError: if an id that owns several nodes, or an endpoint of an
            entity relation, is not in ``entities_dict``.
        KeyError: if a relation type is missing from ``edge_type_ontology``.
    """
    for key, node_ids in node_dict.items():
        if len(node_ids) <= 1:
            continue
        if key not in entities_dict:
            raise RuntimeError(
                "id {!r} owns several nodes but is not a known entity".format(key))
        node_ids = list(node_ids)
        for pos, first in enumerate(node_ids):
            for second in node_ids[pos + 1:]:
                A_true[first, second, 1] = 0
                A_true[first, second, 2] = 1
                A_true[second, first, 1] = 0
                A_true[second, first, 2] = 1
                g.add_edge(first, second)

    seen_pairs = set()
    for start_entity, entity_rela_type, end_entity in entity_relas:
        if start_entity == end_entity:
            continue
        # Unordered pair as the key. Joining the two ids with "+" would let
        # distinct pairs collide ("ab" + "c" == "a" + "bc", 1 + 4 == 2 + 3).
        pair = frozenset((start_entity, end_entity))
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)
        if start_entity not in entities_dict or end_entity not in entities_dict:
            raise RuntimeError(
                "entity relation ({!r}, {!r}, {!r}) refers to an unknown entity".format(
                    start_entity, entity_rela_type, end_entity))
        if start_entity not in node_dict or end_entity not in node_dict:
            continue
        type_id = edge_type_ontology[entity_rela_type]
        for start_id in node_dict[start_entity]:
            for end_id in node_dict[end_entity]:
                A_true[start_id, end_id, 1] = 0
                A_true[start_id, end_id, type_id] = 1
                A_true[end_id, start_id, 1] = 0
                A_true[end_id, start_id, type_id] = 1
                g.add_edge(start_id, end_id)
    return A_true, node_dict, g, entity_relas





def _convert_A_to_dense(matrix):
    """Collapse a one-hot edge tensor into a matrix of edge-type indices.

    Args:
        matrix: array-like indexed as ``matrix[i][j][k]``; the callers in
            this file pass a NumPy array of shape ``(N, N, K)``. Every
            ``matrix[i][j]`` must contain exactly one nonzero entry.

    Returns:
        A float array of shape ``(N, N)`` whose ``[i, j]`` entry is the
        index of the single nonzero entry of ``matrix[i][j]``.

    Raises:
        RuntimeError: if some cell does not have exactly one nonzero entry.
            The first offending cell in row-major order is reported.
    """
    one_hot = np.asarray(matrix)
    num_nodes = len(one_hot)
    ret_mat = np.zeros((num_nodes, num_nodes))
    if num_nodes == 0:
        return ret_mat

    active = one_hot != 0
    active_per_cell = active.sum(axis=2)
    bad_cells = np.argwhere(active_per_cell != 1)
    if len(bad_cells) > 0:
        row, col = bad_cells[0]
        raise RuntimeError(
            "cell ({}, {}) must have exactly one nonzero entry, found {} at {}".format(
                row, col, active_per_cell[row, col], np.nonzero(one_hot[row, col])))

    # argmax picks the position of the single True entry per cell. This also
    # avoids calling int() on a 1-element array, which newer NumPy deprecates.
    ret_mat[:, :active.shape[1]] = active.argmax(axis=2)
    return ret_mat





def _construct_graph(pruned_file):
    """Turn one pickled dataset file into dense adjacency matrices and features.

    The file is unpickled, pruned with ``filter_relations`` and then, graph
    by graph, converted into an initial adjacency tensor
    (``_construct_graph_init``) and a target tensor that additionally holds
    the edges added by ``_construct_A_true``. Both tensors are collapsed to
    edge-type index matrices with ``_convert_A_to_dense``.

    Args:
        pruned_file: path of a pickle holding the five-slot dataset layout
            that ``filter_relations`` expects.

    Returns:
        ``(all_A_init, all_A_true, all_x_features)``: three lists with one
        entry per graph -- the dense initial matrix, the dense target
        matrix and the node features as a NumPy array.

    Raises:
        RuntimeError: if the entries of a tensor do not sum to exactly one
            per node pair.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(pruned_file, "rb") as dataset_file:
        dataset = pickle.load(dataset_file)
    (all_graph_events, all_graph_event_entity_relations,
     all_graph_event_links, all_graph_entities,
     all_graph_entity_relations) = filter_relations(dataset)

    all_A_init = []
    all_A_true = []
    all_x_features = []
    for ii, events in enumerate(all_graph_events):
        print("iter" + str(ii))
        event_entity_relas = all_graph_event_entity_relations[ii]
        event_links = all_graph_event_links[ii]
        entities = all_graph_entities[ii]
        entity_relas = all_graph_entity_relations[ii]

        A_init, node_dict, g, X_features, entity_type_dict = _construct_graph_init(
            events, entities, event_entity_relas, event_links)

        # Work on a copy so A_init keeps the state before the extra edges.
        A_true = np.copy(A_init)
        A_true, node_dict, g, entity_relas = _construct_A_true(
            A_true, node_dict, g, entity_relas, entity_type_dict)

        # Sanity check: the entries must sum to one per node pair.
        for label, tensor in (("A_true", A_true), ("A_init", A_init)):
            num_cells = tensor.shape[0] * tensor.shape[1]
            total = np.sum(tensor)
            if total != num_cells:
                raise RuntimeError(
                    "{} of graph {} in {!r} sums to {} instead of {} "
                    "(one active edge type per node pair)".format(
                        label, ii, pruned_file, total, num_cells))

        all_A_init.append(_convert_A_to_dense(A_init))
        all_A_true.append(_convert_A_to_dense(A_true))
        all_x_features.append(np.array(X_features))
    return all_A_init, all_A_true, all_x_features















def read_dataset():
    """Convert all scenario files and write one pickle per data split.

    For each scenario prefix the train, dev and test files are run through
    ``_construct_graph`` (in that order) and the results are accumulated per
    split. The number of graphs per split is printed, then each split is
    dumped to ``./data/`` as a list ``[A_init, A_true, x_features]``.
    """
    scenarios = ["wiki_ied_bombings_", "suicide_ied_", "wiki_drone_strikes_",
        "wiki_mass_car_bombings_"]
    # One row per split: input directory, input file suffix, output file.
    split_table = (
        ('../D_VAE/data/Wiki_IED_split/train/',
         'train_pruned_new_no_iso_max_50_dataset.pkl',
         './data/train_pruned_with_bert_max_50_set.pkl'),
        ('../D_VAE/data/Wiki_IED_split/dev/',
         'dev_pruned_new_no_iso_max_50_dataset.pkl',
         './data/dev_pruned_with_bert_max_50_set.pkl'),
        ('../D_VAE/data/Wiki_IED_split/test/',
         'test_pruned_new_no_iso_max_50_dataset.pkl',
         './data/test_pruned_with_bert_max_50_set.pkl'),
    )
    # Per split: [all A_init, all A_true, all x_features].
    collected = [[[], [], []] for _ in split_table]

    for scenario in scenarios:
        for bucket, (source_dir, source_suffix, _) in zip(collected, split_table):
            A_init, A_true, x_features = _construct_graph(
                source_dir + scenario + source_suffix)
            bucket[0].extend(A_init)
            bucket[1].extend(A_true)
            bucket[2].extend(x_features)

    for bucket in collected:
        print(len(bucket[0]))

    for bucket, (_, _, target_path) in zip(collected, split_table):
        with open(target_path, 'wb') as handle:
            pickle.dump(bucket, handle)






# NOTE(review): this runs the full dataset conversion whenever the module is
# executed or imported; there is no ``if __name__ == "__main__":`` guard.
read_dataset()







